【踩坑篇】BN

tensorflow 中实现BN的方式有多重:

简介

目前的BN的操作均基于2015年google提出的《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》,但是在实际操作中发现有很多需要特别注意的点

  • 基于mini-batch进行训练,要保证训练和测试的数据的同分布
  • 不同batch的分布的稳定性

公式:y=γ(x-μ)/σ+β

tf.nn.batch_normalization

相关的参数说明:

  • x 输入数据
  • mean 样本均值
  • variance 样本标准差
  • offset 样本的偏倚
  • scale 样本的缩放

因为要考虑训练和测试数据的同质性,故在进行BN时操作是不同的

# 定义BN相关的占位符(给定维度size和权重分配decay)
# pop_mean和pop_vari为训练数据的整体情况的整合
gamma = tf.variable(tf.ones[size])
beta = tf.variable(tf.zeros[size])

pop_mean = tf.variable(tf.zeros[size], trainable=False)
pop_vari = tf.variable(tf.ones[size], trainable=False)

# 针对训练集,是按照batch的均值进行计算训练,故BN时为真实值+偏置后送入模型
batch_mean, batch_variance = tf.nn.moments(layer, [0])
train_mean = tf.assign(pop_mean, pop_mean*decay + batch_mean*(1-decay))
train_vari = tf.assign(pop_vari, pop_vari*decay + batch_vari*(1-decay))
bn_layer = tf.nn.batch_normalization(layer, batch_mean, batch_variance, beta, gamma)

# 针对测试集
bn_layer = tf.nn.batch_normalization(layer, pop_mean, pop_vari, beta, gamma)

tf.layers.batch_normalization

相关参数说明:

  • inputs 输入数据
  • momentum 训练时整体和batch数据作用
  • training 执行类型

因为要考虑训练和测试数据的同质性,故在进行BN时操作是不同的

# 训练数据
bn_layer = tf.layers.batch_normalization(layer, momentum=decay, traning=True)

# 测试数据
bn_layer = tf.layers.batch_normalization(layer, traning=False)

results matching ""

    No results matching ""